We use the pymc3 probabilistic programming library to fit a simplified SEIR model to the COVID-19 data recorded for Lombardy, Italy by the Protezione Civile and made available at https://github.com/pcm-dpc/COVID-19. Model assumptions are discussed and the quality of the fitted model is examined.
The model has the following compartments:
Transitions in the model happen according to the following:
Parameters used above:
Our goal is to fit this model to the data from Lombardy, Italy.
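To fix ideas, the compartmental structure described above can be illustrated with a minimal textbook four-compartment SEIR sketch. This is only a sketch: the actual model in `seir.py` has additional compartments (e.g. detected/undetected infectious and fatality states), and `beta`, `sigma`, and `gamma` here are hypothetical rate parameters, not the fitted ones.

```python
import numpy as np
from scipy.integrate import odeint

def seir_rhs(y, t, beta, sigma, gamma):
    """Right-hand side of a textbook SEIR model (fractions of the population)."""
    s, e, i, r = y
    ds = -beta * s * i               # susceptibles become exposed on contact
    de = beta * s * i - sigma * e    # exposed become infectious at rate sigma
    di = sigma * e - gamma * i       # infectious recover at rate gamma
    dr = gamma * i
    return ds, de, di, dr

t = np.linspace(0, 180, 181)                   # days
y0 = (0.999, 0.001, 0.0, 0.0)                  # start with one exposed per thousand
sol = odeint(seir_rhs, y0, t, args=(0.5, 1 / 5, 1 / 10))
# The four compartments always sum to the whole population
assert np.allclose(sol.sum(axis=1), 1.0)
```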
We also have data for total tests administered (tamponi) and total patients hospitalized (totale_ospedalizzati).
For simplicity, we model the observation errors generically as Gaussian noise whose standard deviation is a constant plus a term that scales linearly with the predicted value.
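Concretely, for a model prediction `mu` the observation is taken to be Gaussian with scale `sigma0 + sigma1 * mu`; in pymc3 such a likelihood would look roughly like `pm.Normal('obs', mu=mu, sigma=sigma0 + sigma1 * mu, observed=data)`. The log-density this assigns can be sketched in plain numpy (the names `sigma0` and `sigma1` are hypothetical, not the parameters used in `seir.py`):

```python
import numpy as np

def gaussian_loglik(y_obs, mu, sigma0, sigma1):
    """Log-likelihood of observations under Gaussian noise whose scale
    grows linearly with the predicted value: sigma = sigma0 + sigma1 * mu."""
    sigma = sigma0 + sigma1 * mu
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
                  - 0.5 * ((y_obs - mu) / sigma) ** 2)

mu = np.array([10.0, 100.0, 1000.0])   # model predictions
y = np.array([12.0, 95.0, 1100.0])     # noisy observations
ll = gaussian_loglik(y, mu, sigma0=1.0, sigma1=0.1)
```

The linear term keeps the noise scale roughly proportional to the signal, so large counts late in the epidemic do not dominate the fit.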
%cd seir
from itertools import chain, islice
import pickle
import numpy as np
import pandas as pd
import pymc3 as pm
from scipy.interpolate import interp1d
from tqdm import tqdm
from data import lombardia
import seir
import util
import holoviews as hv
hv.notebook_extension('bokeh', logo=False)
The model trace contains samples from the posterior for all our parameters. After discarding the burn-in period and sub-sampling to get greater statistical independence between samples, we can use these parameter sets to generate plausible model configurations. For each model state, instead of a single best-fit trace, we get a distribution of traces. Because probability density is not very intuitive, we instead map each trace to a probability on the cumulative distribution of our samples, then compute the tail probability, i.e. the probability that the true value lies farther from the model median than the sample does.
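As a toy illustration of that mapping (using made-up sample values, not the actual trace): the samples are sorted, each sorted sample is assigned its empirical cumulative probability, and the tail probability 100 − 200·|P − 0.5| then gives 100% at the median and approaches 0% in the tails.

```python
import numpy as np

samples = np.sort(np.array([3.0, 1.0, 4.0, 1.5, 9.0]))  # toy posterior samples
P = np.linspace(0, 1, samples.size + 1)[1:]   # cumulative probability of each sorted sample
Ptail = 100 - 200 * np.abs(P - 0.5)           # tail probability in percent
```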
times = pd.date_range(seir.DATE_OF_SIM_TIME_ZERO, '1 June 2020', freq='1d')
P = np.linspace(0, 1, 1001)[1:-1]
Ptail = 100 - 200 * abs(P - 0.5)
cache_file = 'trace_predictions.npy'
try:
    X = np.load(cache_file)
except FileNotFoundError:
    with open('trace.pkl', 'rb') as f:
        trace = pickle.load(f)
    # Wrap in a list just as a lazy way to get better progress bar stats
    post_tune_samples = list(chain(*(islice(trace.points(chains=[i]), 2_000, None, 4) for i in trace.chains)))
    X = []
    for ti in tqdm(post_tune_samples):
        X.append(seir.run_odeint(times, **{k: ti[k] for k in seir.PARAMS}))
    X = np.array(X)
    np.save(cache_file, X)
X.sort(axis=0)
cum_prob = np.linspace(0, 1, X.shape[0] + 1)[1:]
Xi = interp1d(cum_prob, X, axis=0)(P)
df = [{'date': times, 'P': p, 's': x[0], 'e': x[1], 'i0': x[2], 'i0d': x[3], 'i1': x[4], 'i2': x[5], 'f': x[6],
       'fd': x[7], 'r': x[8], 'rd': x[9], 'confirmed cases': x[3] + x[4] + x[5] + x[7] + x[9],
       'unconfirmed cases': x[1] + x[2]}
      for x, p in zip(Xi, Ptail)]
vlines = (
    hv.VLine(pd.Timestamp.now()).options(color='black', line_width=1, line_dash='dashed') *
    hv.VLine(seir.DATE_OF_LOMBARDY_LOCKDOWN).options(color='grey', line_width=1, line_dash='dashed') *
    hv.VLine(seir.DATE_OF_SHUTDOWN_OF_NONESSENTIALS).options(color='grey', line_width=1, line_dash='dashed')
)
Our model produces the following distribution of predictions for total confirmed cases, which fits the confirmed cases in the data well. The plots below show the model's predictions through the first of June, assuming the current policies remain in effect. The bottom plot is identical to the top except that its y-axis is log-scaled.
plot = (
    hv.Contours(df, ['date', 'confirmed cases'], 'P')
    .options(cmap='viridis', colorbar=True, show_legend=False, line_width=2, logz=True,
             cformatter='%.2g%%', aspect=4, responsive=True)
    .redim.range(**{'confirmed cases': (1, None)}) *
    hv.Scatter(lombardia, 'data', 'totale_casi').options(alpha=0.6, color='red', size=6) *
    vlines
)
plot = (plot + plot.options(logy=True)).cols(1)
plot
Our model predicts the following distribution of unconfirmed cases.
plot = (
    hv.Contours(df, ['date', 'unconfirmed cases'], 'P')
    .options(cmap='viridis', colorbar=True, show_legend=False, line_width=2, logz=True,
             cformatter='%.2g%%', aspect=4, responsive=True)
    .redim.range(**{'unconfirmed cases': (1, None)}) *
    vlines
)
plot = (plot + plot.options(logy=True)).cols(1)
plot
Finally, our model predicts the following distribution of deaths, plotted against the recorded deaths (deceduti).
plot = (
    hv.Contours(df, ['date', ('fd', 'deaths')], 'P')
    .options(cmap='viridis', colorbar=True, show_legend=False, line_width=2, logz=True,
             cformatter='%.2g%%', aspect=4, responsive=True)
    .redim.range(deaths=(1, None)) *
    hv.Scatter(lombardia, 'data', 'deceduti').options(alpha=0.6, color='red', size=6) *
    vlines
)
plot = (plot + plot.options(logy=True)).cols(1)
plot
!conda env export --from-history | grep -v 'prefix:'